{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "sAgUgR5Mzzz2"
   },
   "source": [
    "# XLA in Python\n",
    "\n",
    "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/master/docs/notebooks/XLA_in_Python.ipynb)\n",
    "\n",
    "<img style=\"height:100px;\" src=\"https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/compiler/xla/g3doc/images/xlalogo.png\"> <img style=\"height:100px;\" src=\"https://upload.wikimedia.org/wikipedia/commons/c/c3/Python-logo-notext.svg\">\n",
    "\n",
    "_Anselm Levskaya_, _Qiao Zhang_\n",
    "\n",
    "XLA is the compiler that JAX uses, and the compiler that TensorFlow uses for TPUs (and will soon use for all devices), so it's worth some study. However, it's not exactly easy to play with XLA computations directly through the raw C++ interface. JAX exposes the underlying XLA computation builder API through a Python wrapper, which makes the XLA compute model accessible for messing around and prototyping.\n",
    "\n",
    "XLA computations are built as computation graphs in HLO IR, which is then lowered to device-specific code for CPU, GPU, TPU, etc. (for example, via LLVM on CPU and GPU, and LLO on TPU).\n",
    "\n",
    "As end users, we interact with the computational primitives offered to us by the HLO spec.\n",
    "\n",
    "**Caution: This is a pedagogical notebook covering some low-level XLA details; the APIs herein are neither public nor stable!**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "EZK5RseuvZkr"
   },
   "source": [
    "## References\n",
    "\n",
    "__xla__: the doc that defines what's in HLO; note that it is incomplete and omits some ops.\n",
    "\n",
    "https://www.tensorflow.org/xla/operation_semantics\n",
    "\n",
    "More details on the ops are in the source code:\n",
    "\n",
    "https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/client/xla_builder.h\n",
    "\n",
    "__python xla client__: this is the XLA Python client for JAX, and what we're using here.\n",
    "\n",
    "https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/xla_client.py\n",
    "\n",
    "https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/xla_client_test.py\n",
    "\n",
    "__jax__: you can see how JAX interacts with the XLA compute layer for execution and JIT compilation in these files.\n",
    "\n",
    "https://github.com/google/jax/blob/master/jax/lax.py\n",
    "\n",
    "https://github.com/google/jax/blob/master/jax/lib/xla_bridge.py\n",
    "\n",
    "https://github.com/google/jax/blob/master/jax/interpreters/xla.py"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3XR2NGmrzBGe"
   },
   "source": [
    "## Colab Setup and Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "Ogo2SBd3u18P"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# We only need to import JAX's xla_client, not all of JAX.\n",
    "from jax.lib import xla_client as xc\n",
    "xops = xc.ops\n",
    "\n",
    "# Plotting\n",
    "import matplotlib as mpl\n",
    "from matplotlib import pyplot as plt\n",
    "from matplotlib import gridspec\n",
    "from matplotlib import rcParams\n",
    "rcParams['image.interpolation'] = 'nearest'\n",
    "rcParams['image.cmap'] = 'viridis'\n",
    "rcParams['axes.grid'] = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "odmjXyhMuNJ5"
   },
   "source": [
    "## Simple Computations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "UYUtxVzMYIiv",
    "outputId": "5c603ab4-0295-472c-b462-9928b2a9520d"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(0.14112, dtype=float32)"
      ]
     },
     "execution_count": 2,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# make a computation builder\n",
    "c = xc.XlaBuilder(\"simple_scalar\")\n",
    "\n",
    "# define a parameter shape and parameter\n",
    "param_shape = xc.Shape.array_shape(np.dtype(np.float32), ())\n",
    "x = xops.Parameter(c, 0, param_shape)\n",
    "\n",
    "# define computation graph\n",
    "y = xops.Sin(x)\n",
    "\n",
    "# build computation graph\n",
    "# Keep in mind that incorrectly constructed graphs can cause \n",
    "# your notebook kernel to crash!\n",
    "computation = c.Build()\n",
    "\n",
    "# get a CPU backend\n",
    "cpu_backend = xc.get_local_backend(\"cpu\")\n",
    "\n",
    "# compile graph based on shape\n",
    "compiled_computation = cpu_backend.compile(computation)\n",
    "\n",
    "# define a host variable with above parameter shape\n",
    "host_input = np.array(3.0, dtype=np.float32)\n",
    "\n",
    "# place host variable on device and execute\n",
    "device_input = cpu_backend.buffer_from_pyval(host_input)\n",
    "device_out = compiled_computation.execute([device_input])\n",
    "\n",
    "# retrieve the result\n",
    "device_out[0].to_py()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "rIA-IVMVvQs2",
    "outputId": "a4d8ef32-43f3-4a48-f732-e85e158b602e"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.14112  , 0.7568025, 0.9589243], dtype=float32)"
      ]
     },
     "execution_count": 3,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# same as above, but with a vector-shaped parameter:\n",
    "\n",
    "c = xc.XlaBuilder(\"simple_vector\")\n",
    "param_shape = xc.Shape.array_shape(np.dtype(np.float32), (3,))\n",
    "x = xops.Parameter(c, 0, param_shape)\n",
    "\n",
    "# chain steps by reference:\n",
    "y = xops.Sin(x)\n",
    "z = xops.Abs(y)\n",
    "computation = c.Build()\n",
    "\n",
    "# get a CPU backend\n",
    "cpu_backend = xc.get_local_backend(\"cpu\")\n",
    "\n",
    "# compile graph based on shape\n",
    "compiled_computation = cpu_backend.compile(computation)\n",
    "\n",
    "host_input = np.array([3.0, 4.0, 5.0], dtype=np.float32)\n",
    "\n",
    "device_input = cpu_backend.buffer_from_pyval(host_input)\n",
    "device_out = compiled_computation.execute([device_input])\n",
    "\n",
    "# retrieve the result\n",
    "device_out[0].to_py()"
   ]
  },
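  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check, the two elementwise graphs above can also be evaluated eagerly with NumPy on the host. This is a minimal float32 sketch that assumes only NumPy; it does not use the XLA client at all and is just a reference to compare against the `to_py()` results.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Host-side float32 references for the graphs built above.\n",
    "scalar_ref = np.sin(np.float32(3.0))  # Sin(x) on a scalar\n",
    "vector_ref = np.abs(np.sin(np.array([3.0, 4.0, 5.0], dtype=np.float32)))  # Abs(Sin(x))\n",
    "\n",
    "print(scalar_ref)\n",
    "print(vector_ref)\n",
    "```"
   ]
  },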
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "F8kWlLaVuQ1b"
   },
   "source": [
    "## Simple While Loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "MDQP1qW515Ao",
    "outputId": "53245817-b5fb-4285-ee62-7eb33a822be4"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(0, dtype=int32)"
      ]
     },
     "execution_count": 4,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# trivial while loop, decrement until 0\n",
    "#   x = 5\n",
    "#   while x > 0:\n",
    "#     x = x - 1\n",
    "#\n",
    "in_shape = xc.Shape.array_shape(np.dtype(np.int32), ())\n",
    "\n",
    "# body computation:\n",
    "bcb = xc.XlaBuilder(\"bodycomp\")\n",
    "x = xops.Parameter(bcb, 0, in_shape)\n",
    "const1 = xops.Constant(bcb, np.int32(1))\n",
    "y = xops.Sub(x, const1)\n",
    "body_computation = bcb.Build()\n",
    "\n",
    "# test computation:\n",
    "tcb = xc.XlaBuilder(\"testcomp\")\n",
    "x = xops.Parameter(tcb, 0, in_shape)\n",
    "const0 = xops.Constant(tcb, np.int32(0))\n",
    "y = xops.Gt(x, const0)\n",
    "test_computation = tcb.Build()\n",
    "\n",
    "# while computation:\n",
    "wcb = xc.XlaBuilder(\"whilecomp\")\n",
    "x = xops.Parameter(wcb, 0, in_shape)\n",
    "xops.While(test_computation, body_computation, x)\n",
    "while_computation = wcb.Build()\n",
    "\n",
    "# Now compile and execute:\n",
    "# get a CPU backend\n",
    "cpu_backend = xc.get_local_backend(\"cpu\")\n",
    "\n",
    "# compile graph based on shape\n",
    "compiled_computation = cpu_backend.compile(while_computation)\n",
    "\n",
    "host_input = np.array(5, dtype=np.int32)\n",
    "\n",
    "device_input = cpu_backend.buffer_from_pyval(host_input)\n",
    "device_out = compiled_computation.execute([device_input])\n",
    "\n",
    "# retrieve the result\n",
    "device_out[0].to_py()"
   ]
  },
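  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`xops.While` takes two separately built computations, a test (condition) computation and a body computation, plus an initial carry. Both computations receive the whole carry, and the body must return a value of the same shape. The semantics can be modeled roughly in plain Python as below; `while_loop_reference` is a hypothetical helper for illustration only, not part of `xla_client`.\n",
    "\n",
    "```python\n",
    "def while_loop_reference(cond_fn, body_fn, init):\n",
    "    # Hypothetical pure-Python model of XLA's While: cond_fn and body_fn both\n",
    "    # see the full carry. XLA also requires body_fn to return a carry of the\n",
    "    # same shape as init; this sketch does not check that.\n",
    "    carry = init\n",
    "    while cond_fn(carry):\n",
    "        carry = body_fn(carry)\n",
    "    return carry\n",
    "\n",
    "\n",
    "# Same loop as the XLA graph above: decrement until 0.\n",
    "result = while_loop_reference(lambda x: x > 0, lambda x: x - 1, 5)\n",
    "print(result)  # 0\n",
    "```"
   ]
  },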
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7UOnXlY8slI6"
   },
   "source": [
    "## While loops w/ Tuples - Newton's Method for sqrt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "HEWz-vzd6QPR",
    "outputId": "ad4c4247-8e81-4739-866f-2950fec5e759"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "square root of 2.0 is 1.4142156839370728\n"
     ]
    }
   ],
   "source": [
    "Xsqr = 2\n",
    "guess = 1.0\n",
    "converged_delta = 0.001\n",
    "maxit = 1000\n",
    "\n",
    "in_shape_0 = xc.Shape.array_shape(np.dtype(np.float32), ())\n",
    "in_shape_1 = xc.Shape.array_shape(np.dtype(np.float32), ())\n",
    "in_shape_2 = xc.Shape.array_shape(np.dtype(np.int32), ())\n",
    "in_tuple_shape = xc.Shape.tuple_shape([in_shape_0, in_shape_1, in_shape_2])\n",
    "\n",
    "# body computation:\n",
    "# x_{i+1} = x_i - (x_i**2 - y) / (2 * x_i)\n",
    "bcb = xc.XlaBuilder(\"bodycomp\")\n",
    "intuple = xops.Parameter(bcb, 0, in_tuple_shape)\n",
    "y = xops.GetTupleElement(intuple, 0)\n",
    "x = xops.GetTupleElement(intuple, 1)\n",
    "guard_cntr = xops.GetTupleElement(intuple, 2)\n",
    "new_x = xops.Sub(x, xops.Div(xops.Sub(xops.Mul(x, x), y), xops.Add(x, x)))\n",
    "result = xops.Tuple(bcb, [y, new_x, xops.Sub(guard_cntr, xops.Constant(bcb, np.int32(1)))])\n",
    "body_computation = bcb.Build()\n",
    "\n",
    "# test computation -- convergence and max iteration test\n",
    "tcb = xc.XlaBuilder(\"testcomp\")\n",
    "intuple = xops.Parameter(tcb, 0, in_tuple_shape)\n",
    "y = xops.GetTupleElement(intuple, 0)\n",
    "x = xops.GetTupleElement(intuple, 1)\n",
    "guard_cntr = xops.GetTupleElement(intuple, 2)\n",
    "criterion = xops.Abs(xops.Sub(xops.Mul(x, x), y))\n",
    "# keep looping until converged or the iteration budget is exhausted\n",
    "test = xops.And(xops.Gt(criterion, xops.Constant(tcb, np.float32(converged_delta))),\n",
    "               xops.Gt(guard_cntr, xops.Constant(tcb, np.int32(0))))\n",
    "test_computation = tcb.Build()\n",
    "\n",
    "# while computation:\n",
    "# since JAX does not allow users to create a tuple input directly, we need to\n",
    "# take multiple parameters and make an intermediate tuple before feeding it as\n",
    "# the initial carry to the while loop\n",
    "wcb = xc.XlaBuilder(\"whilecomp\")\n",
    "y = xops.Parameter(wcb, 0, in_shape_0)\n",
    "x = xops.Parameter(wcb, 1, in_shape_1)\n",
    "guard_cntr = xops.Parameter(wcb, 2, in_shape_2)\n",
    "tuple_init_carry = xops.Tuple(wcb, [y, x, guard_cntr])\n",
    "xops.While(test_computation, body_computation, tuple_init_carry)\n",
    "while_computation = wcb.Build()\n",
    "\n",
    "# Now compile and execute:\n",
    "cpu_backend = xc.get_local_backend(\"cpu\")\n",
    "\n",
    "# compile graph based on shape\n",
    "compiled_computation = cpu_backend.compile(while_computation)\n",
    "\n",
    "y = np.array(Xsqr, dtype=np.float32)\n",
    "x = np.array(guess, dtype=np.float32)\n",
    "maxit = np.array(maxit, dtype=np.int32)\n",
    "\n",
    "device_input_y = cpu_backend.buffer_from_pyval(y)\n",
    "device_input_x = cpu_backend.buffer_from_pyval(x)\n",
    "device_input_maxit = cpu_backend.buffer_from_pyval(maxit)\n",
    "device_out = compiled_computation.execute([device_input_y, device_input_x, device_input_maxit])\n",
    "\n",
    "# retrieve the result\n",
    "print(\"square root of {y} is {x}\".format(y=y, x=device_out[1].to_py()))"
   ]
  },
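  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To check the XLA loop, here is an equivalent host-side float32 sketch in NumPy that carries the same (y, x, counter) triple through the same test and body logic. `newton_sqrt_reference` is a hypothetical name, and this is not how the XLA client executes the loop. Starting from a guess of 1.0 with the tolerance above, it stops after three Newton steps, leaving 997 on the counter.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def newton_sqrt_reference(y, guess=1.0, converged_delta=0.001, maxit=1000):\n",
    "    # Carry mirrors the XLA tuple (y: f32, x: f32, counter: s32).\n",
    "    y, x, cntr = np.float32(y), np.float32(guess), np.int32(maxit)\n",
    "    # Test computation: continue while |x*x - y| > delta and counter > 0.\n",
    "    while abs(x * x - y) > np.float32(converged_delta) and cntr > 0:\n",
    "        # Body computation: x_{i+1} = x_i - (x_i**2 - y) / (2 * x_i)\n",
    "        x = x - (x * x - y) / (x + x)\n",
    "        cntr = cntr - np.int32(1)\n",
    "    return x, cntr\n",
    "\n",
    "\n",
    "root, remaining = newton_sqrt_reference(2.0)\n",
    "print(root, remaining)\n",
    "```"
   ]
  },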
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "yETVIzTInFYr"
   },
   "source": [
    "## Calculate Symm Eigenvalues"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "AiyR1e2NubKa"
   },
   "source": [
    "Let's exploit the XLA QR implementation to compute the eigenvalues of symmetric matrices.\n",
    "\n",
    "This is the naive (unshifted) QR algorithm, with no acceleration for eigenvalues that are close in magnitude and no permutation to sort the eigenvalues by magnitude."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "id": "wjxDPbqCcuXT",
    "outputId": "2380db52-799d-494e-ded2-856e91f01b0f"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sorted eigenvalues\n",
      "[-1.1406534  -0.5946617  -0.29557052 -0.09876542  0.07503236  0.19509281\n",
      "  0.47496718  0.858686    1.09709     5.281351  ]\n",
      "sorted eigenvalues from numpy\n",
      "[-1.140657   -0.5946614  -0.29557055 -0.09876533  0.07503222  0.19509293\n",
      "  0.47496703  0.85868585  1.0970895   5.2813535 ]\n",
      "sorted error\n",
      "[ 3.5762787e-06 -2.9802322e-07  2.9802322e-08 -8.9406967e-08\n",
      "  1.4156103e-07 -1.1920929e-07  1.4901161e-07  1.1920929e-07\n",
      "  4.7683716e-07 -2.3841858e-06]\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAEICAYAAACHyrIWAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAK1klEQVR4nO3dUayeBX3H8e9v5xxmWx0osCW2ne0FcSEkDnNm0GZegFl0MrmZGSaY6U2zZCoaE4e78XYXxuiFc2lAbyByUbkghohL1Au3BDwUFqRVRyqDAoa6BSQdrKfw38U5M11Le56+fR+ec/58PwlJz3teHn6B8+V537fv+zRVhaQ+fmfqAZLmy6ilZoxaasaopWaMWmrGqKVmjFpqxqjfoJI8keSlJC8meT7Jvyb5myT+TGxx/gd8Y/uLqnoL8A7gH4C/A+6YdpIullGLqnqhqu4F/gr46yTXTL1JszNq/VZVPQgcA/506i2anVHrTM8Ab5t6hGZn1DrTTuC/ph6h2Rm1fivJn7AW9Y+n3qLZGbVI8ntJbgTuBu6sqken3qTZxc9TvzEleQL4A+AU8CpwGLgT+KeqemXCabpIRi0148NvqRmjlpoxaqkZo5aaWRzjoFe8baH27F6a+3F/cfTyuR9T2opefvl5Tq6eyGt9b5So9+xe4sH7d8/9uH/20U/M/ZjSVvTgI/94zu/58FtqxqilZoxaasaopWaMWmrGqKVmBkWd5INJfp7k8SS3jT1K0uw2jDrJAvB14EPA1cDHklw99jBJsxlypn4P8HhVHa2qk6x9kP6mcWdJmtWQqHcCT5329bH12/6fJPuTrCRZOf6ffsZemsrcXiirqgNVtVxVy1devjCvw0q6QEOifho4/Y3cu9Zvk7QJDYn6J8BVSfYmuQS4Gbh33FmSZrXhp7Sq6lSSTwH3AwvAN6vqsdGXSZrJoI9eVtV9wH0jb5E0B76jTGrGqKVmjFpqxqilZoxaamaUCw/+4ujlo1wkMP/yyNyPCVD7/niU40pT8EwtNWPUUjNGLTVj1FIzRi01Y9RSM0YtNWPUUjNGLTVj1FIzRi01Y9RSM0YtNWPUUjNGLTVj1FIzRi01Y9RSM0YtNWPUUjNGLTUzytVExzLWVT+Pv2v73I955b/999yPKQ3hmVpqxqilZoxaasaopWaMWmrGqKVmjFpqZsOok+xO8sMkh5M8luTW12OYpNkMefPJKeDzVXUoyVuAh5L8c1UdHnmbpBlseKauqmer6tD6r18EjgA7xx4maTYX9Jw6yR7gWuCB1/je/iQrSVZWV0/MZ52kCzY46iRvBr4DfLaqfnPm96vqQFUtV9Xy0tKOeW6UdAEGRZ1kibWg76qqe8adJOliDHn1O8AdwJGq+sr4kyRdjCFn6n3Ax4Hrkzyy/tefj7xL0ow2/C2tqvoxkNdhi6Q58B1lUjNGLTVj1FIzRi01s6UuPDiWMS4S+PxV2+Z+TIDL/v2lUY6rPjxTS80YtdSMUUvNGLXUjFFLzRi11IxRS80YtdSMUUvNGLXUjFFLzRi11IxRS80YtdSMUUvNGLXUjFFLzRi11IxRS80YtdSMUUvNeDXRkYx11c+lJ389ynFX//CKUY6r159naqkZo5aaMWqpGaOWmjFqqRmjlpoxaqmZwVEnWUjycJLvjjlI0sW5kDP1rcCRsYZImo9BUSfZBXwYuH3cOZIu1tAz9VeBLwCvnusOSfYnWUmysrp6Yi7jJF24DaNOciPwXFU9dL77VdWBqlququWlpR1zGyjpwgw5U+8DPpLkCeBu4Pokd466StLMNoy6qr5YVbuqag9wM/CDqrpl9GWSZuLvU0vNXNDnqavqR8CPRlkiaS48U0vNGLXUjFFLzRi11IxRS814NdEtZqyrfr70+787ynG3Pfc/oxxX5+aZWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxa
asaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxquJChjvqp8LL7w8ynFfufRNoxy3A8/UUjNGLTVj1FIzRi01Y9RSM0YtNWPUUjODok5yWZKDSX6W5EiS9449TNJshr755GvA96rqL5NcAmwfcZOki7Bh1EkuBd4PfAKgqk4CJ8edJWlWQx5+7wWOA99K8nCS25PsOPNOSfYnWUmysrp6Yu5DJQ0zJOpF4N3AN6rqWuAEcNuZd6qqA1W1XFXLS0tnNS/pdTIk6mPAsap6YP3rg6xFLmkT2jDqqvoV8FSSd67fdANweNRVkmY29NXvTwN3rb/yfRT45HiTJF2MQVFX1SPA8shbJM2B7yiTmjFqqRmjlpoxaqkZo5aa8WqiGtVYV/08tW1h7sdcfOmVuR9zCp6ppWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGCw9qSxrjIoFZHefCg7U0/4skno9naqkZo5aaMWqpGaOWmjFqqRmjlpoxaqmZQVEn+VySx5L8NMm3k4zzp55JumgbRp1kJ/AZYLmqrgEWgJvHHiZpNkMffi8C25IsAtuBZ8abJOlibBh1VT0NfBl4EngWeKGqvn/m/ZLsT7KSZGV19cT8l0oaZMjD77cCNwF7gbcDO5Lccub9qupAVS1X1fLS0o75L5U0yJCH3x8AfllVx6tqFbgHeN+4syTNakjUTwLXJdmeJMANwJFxZ0ma1ZDn1A8AB4FDwKPrf8+BkXdJmtGgz1NX1ZeAL428RdIc+I4yqRmjlpoxaqkZo5aaMWqpGa8mKq0b66qftZhRjnsunqmlZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWZSVfM/aHIc+I8Bd70C+PXcB4xnK+3dSltha+3dDFvfUVVXvtY3Rol6qCQrVbU82YALtJX2bqWtsLX2bvatPvyWmjFqqZmpo95qf3j9Vtq7lbbC1tq7qbdO+pxa0vxNfaaWNGdGLTUzWdRJPpjk50keT3LbVDs2kmR3kh8mOZzksSS3Tr1piCQLSR5O8t2pt5xPksuSHEzysyRHkrx36k3nk+Rz6z8HP03y7SRvmnrTmSaJOskC8HXgQ8DVwMeSXD3FlgFOAZ+vqquB64C/3cRbT3crcGTqEQN8DfheVf0R8C428eYkO4HPAMtVdQ2wANw87aqzTXWmfg/weFUdraqTwN3ATRNtOa+qeraqDq3/+kXWfuh2Trvq/JLsAj4M3D71lvNJcinwfuAOgKo6WVXPT7tqQ4vAtiSLwHbgmYn3nGWqqHcCT5329TE2eSgASfYA1wIPTLtkQ18FvgC8OvWQDewFjgPfWn+qcHuSHVOPOpeqehr4MvAk8CzwQlV9f9pVZ/OFsoGSvBn4DvDZqvrN1HvOJcmNwHNV9dDUWwZYBN4NfKOqrgVOAJv59ZW3svaIci/wdmBHklumXXW2qaJ+Gth92te71m/blJIssRb0XVV1z9R7NrAP+EiSJ1h7WnN9kjunnXROx4BjVfV/j3wOshb5ZvUB4JdVdbyqVoF7gPdNvOksU0X9E+CqJHuTXMLaiw33TrTlvJKEted8R6rqK1Pv2UhVfbGqdlXVHtb+vf6gqjbd2QSgqn4FPJXknes33QAcnnDSRp4Erkuyff3n4gY24Qt7i1P8Q6vqVJJPAfez9griN6vqsSm2DLAP+DjwaJJH1m/7+6q6b8JNnXwauGv9f+5HgU9OvOecquqBJAeBQ6z9rsjDbMK3jPo2UakZXyiTmjFqqRmjlpoxaqkZo5aaMWqpGaOWmvlfrtFe31cYfuIAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light",
      "tags": []
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "Niter = 200\n",
    "matrix_shape = (10, 10)\n",
    "\n",
    "in_shape_0 = xc.Shape.array_shape(np.dtype(np.float32), matrix_shape)\n",
    "in_shape_1 = xc.Shape.array_shape(np.dtype(np.int32), ())\n",
    "in_tuple_shape = xc.Shape.tuple_shape([in_shape_0, in_shape_1])\n",
    "\n",
    "# body computation -- QR loop: X_i = Q R, X_{i+1} = R Q\n",
    "\n",
    "bcb = xc.XlaBuilder(\"bodycomp\")\n",
    "intuple = xops.Parameter(bcb, 0, in_tuple_shape)\n",
    "x = xops.GetTupleElement(intuple, 0)\n",
    "cntr = xops.GetTupleElement(intuple, 1)\n",
    "Q, R = xops.QR(x, True)\n",
    "RQ = xops.Dot(R, Q)\n",
    "xops.Tuple(bcb, [RQ, xops.Sub(cntr, xops.Constant(bcb, np.int32(1)))])\n",
    "body_computation = bcb.Build()\n",
    "\n",
    "# test computation -- just a for loop condition\n",
    "tcb = xc.XlaBuilder(\"testcomp\")\n",
    "intuple = xops.Parameter(tcb, 0, in_tuple_shape)\n",
    "cntr = xops.GetTupleElement(intuple, 1)\n",
    "test = xops.Gt(cntr, xops.Constant(tcb, np.int32(0)))\n",
    "test_computation = tcb.Build()\n",
    "\n",
    "# while computation:\n",
    "wcb = xc.XlaBuilder(\"whilecomp\")\n",
    "x = xops.Parameter(wcb, 0, in_shape_0)\n",
    "cntr = xops.Parameter(wcb, 1, in_shape_1)\n",
    "tuple_init_carry = xops.Tuple(wcb, [x, cntr])\n",
    "xops.While(test_computation, body_computation, tuple_init_carry)\n",
    "while_computation = wcb.Build()\n",
    "\n",
    "# Now compile and execute:\n",
    "cpu_backend = xc.get_local_backend(\"cpu\")\n",
    "\n",
    "# compile graph based on shape\n",
    "compiled_computation = cpu_backend.compile(while_computation)\n",
    "\n",
    "X = np.random.random(matrix_shape).astype(np.float32)\n",
    "X = (X + X.T) / 2.0\n",
    "it = np.array(Niter, dtype=np.int32)\n",
    "\n",
    "device_input_x = cpu_backend.buffer_from_pyval(X)\n",
    "device_input_it = cpu_backend.buffer_from_pyval(it)\n",
    "device_out = compiled_computation.execute([device_input_x, device_input_it])\n",
    "\n",
    "host_out = device_out[0].to_py()\n",
    "eigh_vals = host_out.diagonal()\n",
    "\n",
    "plt.title('D')\n",
    "plt.imshow(host_out)\n",
    "print('sorted eigenvalues')\n",
    "print(np.sort(eigh_vals))\n",
    "print('sorted eigenvalues from numpy')\n",
    "print(np.sort(np.linalg.eigh(X)[0]))\n",
    "print('sorted error')\n",
    "print(np.sort(eigh_vals) - np.sort(np.linalg.eigh(X)[0]))"
   ]
  },
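  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same unshifted QR iteration can be sketched on the host with `np.linalg.qr`, which is handy for checking the XLA version. Each step computes R Q = Q^T (Q R) Q, so every iterate is similar to the original matrix and keeps its eigenvalues, while the off-diagonal entries decay at a rate set by ratios of eigenvalue magnitudes. This sketch uses float64 and a hypothetical 5x5 test matrix built from known eigenvalues; `qr_eigenvalues_reference` is an illustrative name, not part of any library.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def qr_eigenvalues_reference(a, n_iter=200):\n",
    "    # Unshifted QR iteration: A_k = Q_k R_k, A_{k+1} = R_k Q_k = Q_k^T A_k Q_k.\n",
    "    a = np.array(a, dtype=np.float64)\n",
    "    for _ in range(n_iter):\n",
    "        q, r = np.linalg.qr(a)\n",
    "        a = r @ q\n",
    "    # For a symmetric input with distinct |eigenvalues|, A_k tends to diagonal.\n",
    "    return np.diag(a)\n",
    "\n",
    "\n",
    "# Hypothetical symmetric test matrix with known eigenvalues 5, 3, 2, 1, 0.5.\n",
    "rng = np.random.default_rng(0)\n",
    "basis, _ = np.linalg.qr(rng.standard_normal((5, 5)))\n",
    "true_vals = np.array([5.0, 3.0, 2.0, 1.0, 0.5])\n",
    "sym = basis @ np.diag(true_vals) @ basis.T\n",
    "\n",
    "approx = np.sort(qr_eigenvalues_reference(sym))\n",
    "print(approx)\n",
    "print(np.sort(np.linalg.eigvalsh(sym)))\n",
    "```"
   ]
  },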
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "FpggTihknAOw"
   },
   "source": [
    "## Calculate Full Symm Eigensystem"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Qos4ankYuj1T"
   },
   "source": [
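    "We can also calculate the eigenbasis by accumulating the Qs: since each step computes X_{i+1} = R Q = Q^T X_i Q, the running product of the Qs approximates the eigenvector matrix once X_i is close to diagonal. As with the eigenvalues above, the naive iteration converges slowly for eigenvalues that are close in magnitude, so a few entries may still be noticeably off after 200 iterations."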
   ]
  },
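  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A host-side NumPy sketch of the same idea accumulates V = Q_0 Q_1 ... Q_{k-1} alongside the iteration, so that A_k = V^T A_0 V. It assumes a symmetric input whose eigenvalues have distinct magnitudes; `qr_eigensystem_reference` and the 5x5 test matrix are hypothetical and only meant for comparison with the XLA cell below.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def qr_eigensystem_reference(a, n_iter=200):\n",
    "    # Unshifted QR iteration that also accumulates V = Q_0 Q_1 ... Q_{k-1}.\n",
    "    # After k steps A_k = V^T A_0 V, so once A_k is (numerically) diagonal,\n",
    "    # the columns of V are approximate eigenvectors of A_0.\n",
    "    a = np.array(a, dtype=np.float64)\n",
    "    v = np.eye(a.shape[0])\n",
    "    for _ in range(n_iter):\n",
    "        q, r = np.linalg.qr(a)\n",
    "        a = r @ q\n",
    "        v = v @ q\n",
    "    return np.diag(a), v\n",
    "\n",
    "\n",
    "# Hypothetical symmetric test matrix with known eigenvalues 5, 3, 2, 1, 0.5.\n",
    "rng = np.random.default_rng(0)\n",
    "basis, _ = np.linalg.qr(rng.standard_normal((5, 5)))\n",
    "sym = basis @ np.diag([5.0, 3.0, 2.0, 1.0, 0.5]) @ basis.T\n",
    "\n",
    "vals, vecs = qr_eigensystem_reference(sym)\n",
    "# Residual of the eigen-equation A v_j = lambda_j v_j over all columns j.\n",
    "print(np.max(np.abs(sym @ vecs - vecs * vals)))\n",
    "```"
   ]
  },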
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "id": "Kp3A-aAiZk0g",
    "outputId": "bbaff039-20f4-45cd-b8fe-5a664d413f5b"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sorted eigenvalues\n",
      "[-0.95164776 -0.5988633  -0.28330874 -0.07402738  0.15438193  0.19796501\n",
      "  0.4779069   0.58893895  0.81445134  4.5762177 ]\n",
      "sorted eigenvalues from numpy\n",
      "[-0.95164794 -0.6314303  -0.28330857 -0.07402731  0.15438198  0.19796519\n",
      "  0.47790694  0.62150407  0.81445104  4.5762167 ]\n",
      "sorted error\n",
      "[ 1.7881393e-07  3.2567024e-02 -1.7881393e-07 -6.7055225e-08\n",
      " -4.4703484e-08 -1.7881393e-07 -2.9802322e-08 -3.2565117e-02\n",
      "  2.9802322e-07  9.5367432e-07]\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAEICAYAAACHyrIWAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAK5UlEQVR4nO3dT4xdB3mG8efF4yjYoRASVBXbwl4gkJUKBU1RSFQWCaqgpGSD1CAFUTZWpQIBIdHQDdsuEIIFSmUlsElEFiaLCEWESsACITlM/lTBdkBJSB0nRrgtCZFd1x7ydTFT5Nqx5/j6Hp+ZL89PiuSZuTl+lczjc++de49TVUjq401TD5A0X0YtNWPUUjNGLTVj1FIzRi01Y9RSM0b9BpXk+ST/neTVJC8n+VmSv0/i98QG5//AN7a/qaq3AO8C/hn4R+DeaSfpUhm1qKpXquoh4G+BTye5bupNmp1R64+q6lHgCPCXU2/R7IxaZ3sJePvUIzQ7o9bZtgH/NfUIzc6o9UdJ/oKVqH869RbNzqhFkj9JcivwAHBfVT019SbNLr6f+o0pyfPAnwLLwGvAQeA+4F+q6g8TTtMlMmqpGe9+S80YtdSMUUvNGLXUzMIYB7327Ztq547Ncz/ur569Zu7HlDaik//zMqdOH8/rfW2UqHfu2Myjj+yY+3H/6hOfnvsxpY3o0SfvPu/XvPstNWPUUjNGLTVj1FIzRi01Y9RSM4OiTvKRJL9M8kySu8YeJWl2a0adZBPwLeCjwG7gk0l2jz1M0myGnKk/ADxTVc9V1SlW3kh/27izJM1qSNTbgBfO+PjI6uf+nyR7kiwlWTr2n77HXprK3J4oq6q9VbVYVYvvuGbTvA4r6SINifpF4MwXcm9f/ZykdWhI1D8H3p1kV5IrgNuBh8adJWlWa75Lq6qWk3wWeATYBHy7qg6MvkzSTAa99bKqHgYeHnmLpDnwFWVSM0YtNWPUUjNGLTVj1FIzo1x48FfPXjPKRQLzs3+b+zEB6sb3jXJcaQqeqaVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZka5muhYxrrq5+/eu2Xux7z66RNzP6Y0hGdqqRmjlpoxaqkZo5aaMWqpGaOWmjFqqZk1o06yI8mPkxxMciDJnZdjmKTZDHnxyTLwpap6PMlbgMeS/GtVHRx5m6QZrHmmrqqjVfX46q9fBQ4B28YeJmk2F/WYOslO4Hpg/+t8bU+SpSRLp5ePz2edpIs2OOokVwHfA75QVb8/++tVtbeqFqtqcfPC1nlulHQRBkWdZDMrQd9fVQ+OO0nSpRjy7HeAe4FDVfX18SdJuhRDztQ3AZ8Cbk7y5Oo/fz3yLkkzWvNHWlX1UyCXYYukOfAVZVIzRi01Y9RSM0YtNbOhLjw4ljEuErjw7NG5HxPg5J/vGOW4CyeWRzmuLj/P1FIzRi01Y9RSM0YtNWPUUjNGLTVj1FIzRi01Y9RSM0YtNWPUUjNGLTVj1FIzRi01Y9RSM0YtNWPUUjNGLTVj1FIzRi01Y9RSM15NdCRjXfXz5LWbRznuVYe9mmgXnqmlZoxaasaopWaMWmrGqKVmjFpqxqilZgZHnWRTkieSfH/MQZIuzcWcqe8EDo01RNJ8DIo6yXbgY8A9486RdKmGnqm/AXwZeO18N0iyJ8lSkqXTy8fnMk7SxVsz6iS3Ar+tqscudLuq2ltVi1W1uHlh69wGSro4Q87UNwEfT/I88ABwc5L7Rl0laWZrRl1VX6mq7VW1E7gd+FFV3TH6Mkkz8efUUjMX9X7qqvoJ8JNRlkiaC8/UUjNGLTVj1FIzRi01Y9RSM15NdCQLJ8a5OudYV/088WdXjnLcLUdPjnJcnZ9naqkZo5aaMWqpGaOWmjFqqRmjlpoxaqkZo5aa
MWqpGaOWmjFqqRmjlpoxaqkZo5aaMWqpGaOWmjFqqRmjlpoxaqkZo5aaMWqpGa8mKmC8q34u/O7EKMddvnrLKMftwDO11IxRS80YtdSMUUvNGLXUjFFLzRi11MygqJO8Lcm+JE8nOZTkg2MPkzSboS8++Sbwg6r6RJIrAH/yL61Ta0ad5K3Ah4C/A6iqU8CpcWdJmtWQu9+7gGPAd5I8keSeJFvPvlGSPUmWkiydXj4+96GShhkS9QLwfuDuqroeOA7cdfaNqmpvVS1W1eLmhXOal3SZDIn6CHCkqvavfryPlcglrUNrRl1VvwFeSPKe1U/dAhwcdZWkmQ199vtzwP2rz3w/B3xmvEmSLsWgqKvqSWBx5C2S5sBXlEnNGLXUjFFLzRi11IxRS814NVGNaqyrfi5vmf+37sKJ5bkfcwqeqaVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxgsPakMa4yKBf7hynBw2nby8FzT0TC01Y9RSM0YtNWPUUjNGLTVj1FIzRi01MyjqJF9MciDJL5J8N8mVYw+TNJs1o06yDfg8sFhV1wGbgNvHHiZpNkPvfi8Ab06yAGwBXhpvkqRLsWbUVfUi8DXgMHAUeKWqfnj27ZLsSbKUZOn08vH5L5U0yJC731cDtwG7gHcCW5PccfbtqmpvVS1W1eLmha3zXyppkCF3vz8M/LqqjlXVaeBB4MZxZ0ma1ZCoDwM3JNmSJMAtwKFxZ0ma1ZDH1PuBfcDjwFOr/87ekXdJmtGgN5BW1VeBr468RdIc+IoyqRmjlpoxaqkZo5aaMWqpGa8mKq0a66qfbzr92vwPWhf4/eb/u0maklFLzRi11IxRS80YtdSMUUvNGLXUjFFLzRi11IxRS80YtdSMUUvNGLXUjFFLzRi11IxRS80YtdSMUUvNGLXUjFFLzRi11EyqLnBZwlkPmhwD/n3ATa8F/mPuA8azkfZupK2wsfauh63vqqp3vN4XRol6qCRLVbU42YCLtJH2bqStsLH2rvet3v2WmjFqqZmpo95of3n9Rtq7kbbCxtq7rrdO+pha0vxNfaaWNGdGLTUzWdRJPpLkl0meSXLXVDvWkmRHkh8nOZjkQJI7p940RJJNSZ5I8v2pt1xIkrcl2Zfk6SSHknxw6k0XkuSLq98Hv0jy3SRXTr3pbJNEnWQT8C3go8Bu4JNJdk+xZYBl4EtVtRu4AfiHdbz1THcCh6YeMcA3gR9U1XuB97GONyfZBnweWKyq64BNwO3TrjrXVGfqDwDPVNVzVXUKeAC4baItF1RVR6vq8dVfv8rKN922aVddWJLtwMeAe6beciFJ3gp8CLgXoKpOVdXL065a0wLw5iQLwBbgpYn3nGOqqLcBL5zx8RHWeSgASXYC1wP7p12ypm8AXwZG+NvO52oXcAz4zupDhXuSbJ161PlU1YvA14DDwFHglar64bSrzuUTZQMluQr4HvCFqvr91HvOJ8mtwG+r6rGptwywALwfuLuqrgeOA+v5+ZWrWblHuQt4J7A1yR3TrjrXVFG/COw44+Ptq59bl5JsZiXo+6vqwan3rOEm4ONJnmflYc3NSe6bdtJ5HQGOVNX/3fPZx0rk69WHgV9X1bGqOg08CNw48aZzTBX1z4F3J9mV5ApWnmx4aKItF5QkrDzmO1RVX596z1qq6itVtb2qdrLy3/VHVbXuziYAVfUb4IUk71n91C3AwQknreUwcEOSLavfF7ewDp/YW5jiN62q5SSfBR5h5RnEb1fVgSm2DHAT8CngqSRPrn7un6rq4Qk3dfI54P7VP9yfAz4z8Z7zqqr9SfYBj7PyU5EnWIcvGfVlolIzPlEmNWPUUjNGLTVj1FIzRi01Y9RSM0YtNfO/5H9lzlLkxRMAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light",
      "tags": []
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAEICAYAAACHyrIWAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAM+klEQVR4nO3dfWyd9XnG8evCjomdhMQMNEoSiCfaQqi0BVkoNFpVEba1g5Vpa7dUoxPVtFRTX0LLWsE0jQrtv6KKSo3YokA1CVRYUyahDhG2tZWGVGWYwASJCSIv5JUl0JJASHCc3PvDnpQlOH5s/3489p3vR0KKzznc3DL5+jnn+PFjR4QA5HFB2wsAKIuogWSIGkiGqIFkiBpIhqiBZIgaSIaoz1O2w/ZVZ9z2bdsPt7UTyiBqIBmiBpIhaiAZogaSIerz10lJs864bZakEy3sgoKI+vy1W9KSM27rk/TaB78KSiLq89djkv7W9iLbF9i+SdIfSNrQ8l6YIvPz1Ocn292S7pX0OUm9krZL+nZEPNHqYpgyogaS4ek3kAxRA8kQNZAMUQPJdNYYOmt+d8y+bH7xuXHozHMlyjgxr/ybhbO765zDMTTcUWVu1/4qY3Xywjr76tLyn9+OXaeKz5Sk44u6is8cPvQrnXz7qN/vvipRz75svvof+LPic4//44eKz5Sk/SvLR33t1XuKz5SknW9eXGXuFfecrDL3nY+U/+IuSf7SweIz591+rPhMSRq89/LiMw/83dox7+PpN5AMUQPJEDWQDFEDyRA1kAxRA8k0itr2p2xvs/2q7btqLwVg8saN2naHpLWSPi1pqaTP215aezEAk9PkSH29pFcjYkdEDEl6VNKtddcCMFlNol4o6fTTo/aO3vb/2F5te8D2wIm33i21H4AJKvZGWUSsi4j+iOiftaCn1FgAE9Qk6n2SFp/28aLR2wBMQ02iflbSh2332e6StEoS17ECpqlxf0orIoZtf0XSRkkdkh6KiC3VNwMwKY1+9DIinpT0ZOVdABTAGWVAMkQNJEPUQDJEDSRD1EAyVS48OKdjSNddXP7Ce7d/57HiMyVp1do7i8/cOnRF8ZmS5OH3vYDklF3x0H9Xmbv7k3X2ffn3ri0+s3NNlRzUNfud4jPtsS+WyZEaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWSIGkimyuUTj52apcEjlxWf+yc/Kn/VT0n6nVX/VXzm4F9dU3ymJG3/0zlV5nac4+qUU9H91Nwqc3sOvVt8Zu/TdXa9+7M/Kj7zGz2/HPM+jtRAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMuNGbXux7Z/Z3mp7i+01H8RiACanycknw5LujIjNtudJes72v0XE1sq7AZiEcY/UEXEgIjaP/vltSYOSFtZeDMDkTOg1te0lkpZJ2vQ+9622PWB7YOitY2W2AzBhjaO2PVfSjyXdERFHzrw/ItZFRH9E9Hct6C65I4AJaBS17VkaCfqRiHi87koApqLJu9+W9KCkwYj4bv2VAExFkyP1CklfkHSj7RdG//n9ynsBmKRxv6UVEc9I8gewC4ACOKMMSIaogWSIGkiGqIFkqlx4cG7He1p+8c7icw9vu6L4TEna+MT1xWd2frL4SElS9+t15j6zr6/K3GOvLKgy9z9Wfaf4zL9csKr4TEn6/h/9YfGZB7c/OOZ9HKmBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWSqXE10OC7QGyfmFp+793ej+ExJumRT+blv/ladXXd87h+qzF32bJ0rad7w21uqzF31zb8uPvO9+XV+u1TPVaeKzzz5WseY93GkBpIhaiAZogaS
IWogGaIGkiFqIBmiBpJpHLXtDtvP2/5JzYUATM1EjtRrJA3WWgRAGY2itr1I0s2S1tddB8BUNT1S3y/pW5LGPN/N9mrbA7YHjv3qvSLLAZi4caO2fYukgxHx3LkeFxHrIqI/Ivq7ey8stiCAiWlypF4h6TO2d0l6VNKNth+uuhWASRs36oi4OyIWRcQSSask/TQibqu+GYBJ4fvUQDIT+nnqiPi5pJ9X2QRAERypgWSIGkiGqIFkiBpIhqiBZKpcTfS9U53afbS3+Nylf7+v+ExJ2nrvh4rP7L7oePGZktS38S+qzL1wzlCVua+sXVpl7oLt7xSfefM//WfxmZK08VD5z0HHS2P//+JIDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kU+VqohHW8ZOzis89sr7O772+9OTh8jO/9G7xmZK05/sXVZm76M/rXKl1+zevrTL3yJLyn4dnj1xZfKYkbdu0pPjM4+90jXkfR2ogGaIGkiFqIBmiBpIhaiAZogaSIWogmUZR215ge4Ptl20P2r6h9mIAJqfpySffk/RURHzWdpeknoo7AZiCcaO2PV/SJyTdLkkRMSSpzi8zBjBlTZ5+90k6JOkHtp+3vd72nDMfZHu17QHbA0OHjxVfFEAzTaLulHSdpAciYpmko5LuOvNBEbEuIvojor9rfnfhNQE01STqvZL2RsSm0Y83aCRyANPQuFFHxOuS9tj+6OhNKyVtrboVgElr+u73VyU9MvrO9w5JX6y3EoCpaBR1RLwgqb/yLgAK4IwyIBmiBpIhaiAZogaSIWogmSpXE53dcUIfuehg8bkHj88rPlOSlvS8WXzmHb/4RfGZkrT8X75RZe6uNb1V5v7xLc9Umfuvr5W/Sukbx+cWnylJvvJo+ZkXnhrzPo7UQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRT5cKDR4e7NHDoiuJz9+28pPhMSVqyvPyFB2tdIPCa+/+nytzBO3+9ytx//vcVVeZe9djbxWce7ltcfKYk/caud4vPPPj62PdxpAaSIWogGaIGkiFqIBmiBpIhaiAZogaSaRS17a/b3mL7Jds/tD279mIAJmfcqG0vlPQ1Sf0R8TFJHZJW1V4MwOQ0ffrdKanbdqekHkn7660EYCrGjToi9km6T9JuSQckHY6Ip898nO3VtgdsD5w4fKz8pgAaafL0u1fSrZL6JF0uaY7t2858XESsi4j+iOifNb+7/KYAGmny9PsmSTsj4lBEnJD0uKSP110LwGQ1iXq3pOW2e2xb0kpJg3XXAjBZTV5Tb5K0QdJmSS+O/jvrKu8FYJIa/Tx1RNwj6Z7KuwAogDPKgGSIGkiGqIFkiBpIhqiBZKpcTXR2x7Cu6S1/1cv9r1xafKYkPbes/Ne2C+5z8ZmSdPLX5lWZe/ELdb6+X3X7tipzX/zl1cVnDvVG8ZmS1PvlN4rP9OrhMe/jSA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJOOI8ldQtH1I0msNHnqJpPKXWqxnJu07k3aVZta+02HXKyPifS+vWyXqpmwPRER/awtM0EzadybtKs2sfaf7rjz9BpIhaiCZtqOeab+8fibtO5N2lWbWvtN611ZfUwMor+0jNYDCiBpIprWobX/K9jbbr9q+q609xmN7se2f2d5qe4vtNW3v1ITtDtvP2/5J27uci+0FtjfYftn2oO0b2t7pXGx/ffTvwUu2f2h7dts7namVqG13SFor6dOSlkr6vO2lbezSwLCkOyNiqaTlkr48jXc93RpJg20v0cD3JD0VEVdL+k1N451tL5T0NUn9EfEx
SR2SVrW71dnaOlJfL+nViNgREUOSHpV0a0u7nFNEHIiIzaN/flsjf+kWtrvVudleJOlmSevb3uVcbM+X9AlJD0pSRAxFxFvtbjWuTkndtjsl9Uja3/I+Z2kr6oWS9pz28V5N81AkyfYSScskbWp3k3HdL+lbkk61vcg4+iQdkvSD0ZcK623PaXupsUTEPkn3Sdot6YCkwxHxdLtbnY03yhqyPVfSjyXdERFH2t5nLLZvkXQwIp5re5cGOiVdJ+mBiFgm6aik6fz+Sq9GnlH2Sbpc0hzbt7W71dnainqfpMWnfbxo9LZpyfYsjQT9SEQ83vY+41gh6TO2d2nkZc2Nth9ud6Ux7ZW0NyL+75nPBo1EPl3dJGlnRByKiBOSHpf08ZZ3OktbUT8r6cO2+2x3aeTNhida2uWcbFsjr/kGI+K7be8znoi4OyIWRcQSjXxefxoR0+5oIkkR8bqkPbY/OnrTSklbW1xpPLslLbfdM/r3YqWm4Rt7nW38RyNi2PZXJG3UyDuID0XEljZ2aWCFpC9IetH2C6O3/U1EPNniTpl8VdIjo1/cd0j6Ysv7jCkiNtneIGmzRr4r8rym4SmjnCYKJMMbZUAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAy/wvI0tLk+VKqNgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light",
      "tags": []
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAEICAYAAACHyrIWAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAALx0lEQVR4nO3dX4xc9XmH8eeL1whsqoRCKjW2VbttSoRSpURbREDiAmiVlDRILRcgQDQ3qGqTkDRSRKpKVL2soii5SGktkvQCFFIZlCKKApWSXESRHMyfKmATxJ/UGEzjVOSfXWMvvL3YieTaeHe8nsPZfft8pJV2zgxnXlv7cM7MnP05VYWkPs4YewBJs2XUUjNGLTVj1FIzRi01Y9RSM0bdTJLzk9ye5LKxZ9E4jHoNSFJJfvu4bX+b5K7jtm0E/g34A+CBJBcdc98NSX4x+fqfJG8cc/sXSzx3kjyfZPes5tSwjLqJJOuBe4HdwOXAnwP3J/ktgKq6u6rOqapzgA8CL//y9mTbyVwO/Brwm0l+f9g/hWZhbuwBdPqSBPhn4AXgL2rxMsGvJXmNxbCvqKr/WuHubwb+FTh78v0jMxhZAzLqBiYR3/Am278OfH2l+02yAbgWuI7FqP8pyV9V1ZGV7lPD8/RbS/kT4DXgYRZfq68Hrh51Ii3LqNeG11kM6ljrgaMDP+/NwL9U1UJVHWbxNfvNSzx+rDl1DE+/14a9wFZgzzHbtgHPDPWESTYDVwAXJ/nTyeYNwFlJzq+qH6+GOXUij9Rrw9eAv0myOckZSa4C/hjYMeBz3sRijBcAvzf5+h1gH3D9KppTxzHqteHvgO8C3wFeBf4euKGqnhzwOW8G/qGqXjn2C/hHTn4KPsacOk5cJEHqxSO11IxRS80YtdSMUUvNDPI59fm/uq62bjn+GoTT98xz5818n9JadPi1n3Dk6MG82X2DRL11y3q+99CWme/3D69d6mIm6f+P7z1xx0nv8/RbasaopWaMWmrGqKVmjFpqxqilZqaKOskHkvwgybNJbht6KEkrt2zUSdYBX2RxBcoLgeuTXDj0YJJWZpoj9cXAs1X1/GTBuXuAa4YdS9JKTRP1JuDFY27vm2z7P5LckmRXkl0H/vv1Wc0n6RTN7I2yqtpeVfNVNf+O89bNareSTtE0Ub8EHHsh9+bJNkmr0DRRPwK8K8m2JGeyuLD7/cOOJWmllv0trapaSPJR4CFgHfDlqnpq8MkkrchUv3pZVQ8CDw48i6QZ8IoyqRmjlpoxaqkZo5aaMWqpmUEWHnzmufMGWSQw3/2Pme8ToC597yD7lcbgkVpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaamaQ1USHMtSqn6++e8PM93nu04dmvk9pGh6ppWaMWmrGqKVmjFpqxqilZoxaasaopWaWjTrJliTfSrI7yVNJbn0rBpO0MtNcfLIAfKqqHkvyK8CjSf69qnYPPJukFVj2SF1V+6vqscn3Pwf2AJuGHkzSypzSa+okW4GLgJ1vct8tSXYl2XV04eBsppN0yqaOOsk5wL3AJ6rqZ8ffX1Xbq2q+qubXz22c5YySTsFUUSdZz2LQd1fVfcOOJOl0TPPud4AvAXuq6nPDjyTpdExzpL4MuAm4IskTk68/GnguSSu07EdaVfUdIG/BLJJmwCvKpGaMWmrGqKVmjFpqZk0tPDiUIRYJnHtu/8z3CXD4d7cMst+5QwuD7FdvPY/UUjNGLTVj1FIzRi01Y9RSM0YtNWPUUjNGLTVj1FIzRi01Y9RSM0YtNWPUUjNGLTVj1FIzRi01Y9RSM0YtNWPUUjNGLTVj1FIzriY6kKFW/Tx8/vpB9nvOXlcT7cIjtdSMUUvNGLXUjFFLzRi11IxRS80YtdTM1FEnWZfk8SQP
DDmQpNNzKkfqW4E9Qw0iaTamijrJZuBq4M5hx5F0uqY9Un8e+DTwxskekOSWJLuS7Dq6cHAmw0k6dctGneRDwI+q6tGlHldV26tqvqrm189tnNmAkk7NNEfqy4APJ/khcA9wRZK7Bp1K0ootG3VVfaaqNlfVVuA64JtVdePgk0laET+nlpo5pd+nrqpvA98eZBJJM+GRWmrGqKVmjFpqxqilZoxaasbVRAcyd2iY1TmHWvXz0K+fNch+N+w/PMh+dXIeqaVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZlxNVMBwq37OvXpokP0unLthkP124JFaasaopWaMWmrGqKVmjFpqxqilZoxaamaqqJO8PcmOJE8n2ZPk/UMPJmllpr345AvAN6rq2iRnAn7yL61Sy0ad5G3A5cCfAVTVEeDIsGNJWqlpTr+3AQeAryR5PMmdSTYe/6AktyTZlWTX0YWDMx9U0nSmiXoOeB9wR1VdBBwEbjv+QVW1varmq2p+/dwJzUt6i0wT9T5gX1XtnNzewWLkklahZaOuqleAF5NcMNl0JbB70Kkkrdi0735/DLh78s7388BHhhtJ0umYKuqqegKYH3gWSTPgFWVSM0YtNWPUUjNGLTVj1FIzriaqQQ216ufChtn/6M4dWpj5PsfgkVpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZlx4UGvSEIsEvn7WMDmsO/zWLmjokVpqxqilZoxaasaopWaMWmrGqKVmjFpqZqqok3wyyVNJnkzy1SRnDT2YpJVZNuokm4CPA/NV9R5gHXDd0INJWplpT7/ngLOTzAEbgJeHG0nS6Vg26qp6CfgssBfYD/y0qh4+/nFJbkmyK8muowsHZz+ppKlMc/p9LnANsA14J7AxyY3HP66qtlfVfFXNr5/bOPtJJU1lmtPvq4AXqupAVR0F7gMuHXYsSSs1TdR7gUuSbEgS4Epgz7BjSVqpaV5T7wR2AI8B35/8N9sHnkvSCk31C6RVdTtw+8CzSJoBryiTmjFqqRmjlpoxaqkZo5aacTVRaWKoVT/POPrG7HdaSzzf7J9N0piMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmUrXEsoQr3WlyAPjPKR56PvDjmQ8wnLU071qaFdbWvKth1t+oqne82R2DRD2tJLuqan60AU7RWpp3Lc0Ka2ve1T6rp99SM0YtNTN21GvtH69fS/OupVlhbc27qmcd9TW1pNkb+0gtacaMWmpmtKiTfCDJD5I8m+S2seZYTpItSb6VZHeSp5LcOvZM00iyLsnjSR4Ye5alJHl7kh1Jnk6yJ8n7x55pKUk+Ofk5eDLJV5OcNfZMxxsl6iTrgC8CHwQuBK5PcuEYs0xhAfhUVV0IXAL85Sqe9Vi3AnvGHmIKXwC+UVXvBt7LKp45ySbg48B8Vb0HWAdcN+5UJxrrSH0x8GxVPV9VR4B7gGtGmmVJVbW/qh6bfP9zFn/oNo071dKSbAauBu4ce5alJHkbcDnwJYCqOlJVPxl3qmXNAWcnmQM2AC+PPM8Jxop6E/DiMbf3scpDAUiyFbgI2DnuJMv6PPBpYIB/7XymtgEHgK9MXircmWTj2EOdTFW9BHwW2AvsB35aVQ+PO9WJfKNsSknOAe4FPlFVPxt7npNJ8iHgR1X16NizTGEOeB9wR1VdBBwEVvP7K+eyeEa5DXgnsDHJjeNOdaKxon4J2HLM7c2TbatSkvUsBn13Vd039jzLuAz4cJIfsviy5ookd4070kntA/ZV1S/PfHawGPlqdRXwQlUdqKqjwH3ApSPPdIKxon4EeFeSbUnO
ZPHNhvtHmmVJScLia749VfW5sedZTlV9pqo2V9VWFv9ev1lVq+5oAlBVrwAvJrlgsulKYPeIIy1nL3BJkg2Tn4srWYVv7M2N8aRVtZDko8BDLL6D+OWqemqMWaZwGXAT8P0kT0y2/XVVPTjiTJ18DLh78j/354GPjDzPSVXVziQ7gMdY/FTkcVbhJaNeJio14xtlUjNGLTVj1FIzRi01Y9RSM0YtNWPUUjP/C52Cw8vnl5XSAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light",
      "tags": []
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "Niter = 100\n",
    "matrix_shape = (10, 10)\n",
    "\n",
    "in_shape_0 = xc.Shape.array_shape(np.dtype(np.float32), matrix_shape)\n",
    "in_shape_1 = xc.Shape.array_shape(np.dtype(np.float32), matrix_shape)\n",
    "in_shape_2 = xc.Shape.array_shape(np.dtype(np.int32), ())\n",
    "in_tuple_shape = xc.Shape.tuple_shape([in_shape_0, in_shape_1, in_shape_2])\n",
    "\n",
    "# body computation -- QR loop: X_i = Q R, X_{i+1} = R Q (O accumulates the Q factors)\n",
    "bcb = xc.XlaBuilder(\"bodycomp\")\n",
    "intuple = xops.Parameter(bcb, 0, in_tuple_shape)\n",
    "X = xops.GetTupleElement(intuple, 0)\n",
    "O = xops.GetTupleElement(intuple, 1)\n",
    "cntr = xops.GetTupleElement(intuple, 2)\n",
    "Q, R = xops.QR(X, True)\n",
    "RQ = xops.Dot(R, Q)\n",
    "Onew = xops.Dot(O, Q)\n",
    "xops.Tuple(bcb, [RQ, Onew, xops.Sub(cntr, xops.Constant(bcb, np.int32(1)))])\n",
    "body_computation = bcb.Build()\n",
    "\n",
    "# test computation -- just a for loop condition\n",
    "tcb = xc.XlaBuilder(\"testcomp\")\n",
    "intuple = xops.Parameter(tcb, 0, in_tuple_shape)\n",
    "cntr = xops.GetTupleElement(intuple, 2)\n",
    "test = xops.Gt(cntr, xops.Constant(tcb, np.int32(0)))\n",
    "test_computation = tcb.Build()\n",
    "\n",
    "# while computation:\n",
    "wcb = xc.XlaBuilder(\"whilecomp\")\n",
    "X = xops.Parameter(wcb, 0, in_shape_0)\n",
    "O = xops.Parameter(wcb, 1, in_shape_1)\n",
    "cntr = xops.Parameter(wcb, 2, in_shape_2)\n",
    "tuple_init_carry = xops.Tuple(wcb, [X, O, cntr])\n",
    "xops.While(test_computation, body_computation, tuple_init_carry)\n",
    "while_computation = wcb.Build()\n",
    "\n",
    "# Now compile and execute:\n",
    "cpu_backend = xc.get_local_backend(\"cpu\")\n",
    "\n",
    "# compile graph based on shape\n",
    "compiled_computation = cpu_backend.compile(while_computation)\n",
    "\n",
    "X = np.random.random(matrix_shape).astype(np.float32)\n",
    "X = (X + X.T) / 2.0\n",
    "Omat = np.eye(matrix_shape[0], dtype=np.float32)\n",
    "it = np.array(Niter, dtype=np.int32)\n",
    "\n",
    "device_input_X = cpu_backend.buffer_from_pyval(X)\n",
    "device_input_Omat = cpu_backend.buffer_from_pyval(Omat)\n",
    "device_input_it = cpu_backend.buffer_from_pyval(it)\n",
    "device_out = compiled_computation.execute([device_input_X, device_input_Omat, device_input_it])\n",
    "\n",
    "host_out = device_out[0].to_py()\n",
    "eigh_vals = host_out.diagonal()\n",
    "eigh_mat = device_out[1].to_py()\n",
    "\n",
    "plt.title('D')\n",
    "plt.imshow(host_out)\n",
    "plt.figure()\n",
    "plt.title('U')\n",
    "plt.imshow(eigh_mat)\n",
    "plt.figure()\n",
    "plt.title('U^T A U')\n",
    "plt.imshow(np.dot(np.dot(eigh_mat.T, X), eigh_mat))\n",
    "print('sorted eigenvalues')\n",
    "print(np.sort(eigh_vals))\n",
    "print('sorted eigenvalues from numpy')\n",
    "print(np.sort(np.linalg.eigh(X)[0]))\n",
    "print('sorted error') \n",
    "print(np.sort(eigh_vals) - np.sort(np.linalg.eigh(X)[0]))"
   ]
  },
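  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The body computation above is unshifted QR iteration: each step factors `X = Q R` and replaces `X` with `R Q = Q^T X Q`, so every iterate is similar to the input, while `O` accumulates the product of the `Q` factors. For a symmetric matrix whose eigenvalues have distinct magnitudes, the iterates approach a diagonal matrix; without shifts, convergence is slow when two eigenvalues have similar magnitudes.\n",
    "\n",
    "As a rough cross-check, here is a minimal sketch of the same loop in plain NumPy rather than XLA. The function name `qr_iteration` and the 4x4 test matrix with eigenvalues 1, 2, 4 and 8 are our own illustrative choices, not part of the original example.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def qr_iteration(A, num_iters=100):\n",
    "  # Unshifted QR iteration: factor X_k = Q_k R_k, then set X_{k+1} = R_k Q_k.\n",
    "  X = np.array(A, dtype=np.float64)\n",
    "  U = np.eye(X.shape[0])\n",
    "  for _ in range(num_iters):\n",
    "    Q, R = np.linalg.qr(X)\n",
    "    X = R @ Q  # equals Q^T X Q, so X stays similar to A\n",
    "    U = U @ Q  # accumulate the similarity transforms (like O above)\n",
    "  return X, U  # U^T A U equals X up to rounding\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "Q0, _ = np.linalg.qr(rng.standard_normal((4, 4)))\n",
    "A = Q0 @ np.diag([1.0, 2.0, 4.0, 8.0]) @ Q0.T  # symmetric, known spectrum\n",
    "D, U = qr_iteration(A, num_iters=200)\n",
    "print(np.sort(D.diagonal()))  # close to 1, 2, 4, 8\n",
    "```"
   ]
  },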
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Ee3LMzOvlCuK"
   },
   "source": [
    "## Convolutions\n",
    "\n",
    "I keep hearing from the AGI folks that we can use convolutions to build artificial life. Let's try it out by running Conway's Game of Life as a 3x3 convolution inside an XLA `While` loop."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "id": "9xh6yeXKS9Vg"
   },
   "outputs": [],
   "source": [
    "# Here we borrow convenience functions from LAX to handle conv dimension numbers.\n",
    "from typing import NamedTuple, Sequence\n",
    "\n",
    "class ConvDimensionNumbers(NamedTuple):\n",
    "  \"\"\"Describes batch, spatial, and feature dimensions of a convolution.\n",
    "\n",
    "  Args:\n",
    "    lhs_spec: a tuple of nonnegative integer dimension numbers containing\n",
    "      `(batch dimension, feature dimension, spatial dimensions...)`.\n",
    "    rhs_spec: a tuple of nonnegative integer dimension numbers containing\n",
    "      `(out feature dimension, in feature dimension, spatial dimensions...)`.\n",
    "    out_spec: a tuple of nonnegative integer dimension numbers containing\n",
    "      `(batch dimension, feature dimension, spatial dimensions...)`.\n",
    "  \"\"\"\n",
    "  lhs_spec: Sequence[int]\n",
    "  rhs_spec: Sequence[int]\n",
    "  out_spec: Sequence[int]\n",
    "\n",
    "def _conv_general_proto(dimension_numbers):\n",
    "  assert type(dimension_numbers) is ConvDimensionNumbers\n",
    "  lhs_spec, rhs_spec, out_spec = dimension_numbers\n",
    "  proto = xc.ConvolutionDimensionNumbers()\n",
    "  proto.input_batch_dimension = lhs_spec[0]\n",
    "  proto.input_feature_dimension = lhs_spec[1]\n",
    "  proto.output_batch_dimension = out_spec[0]\n",
    "  proto.output_feature_dimension = out_spec[1]\n",
    "  proto.kernel_output_feature_dimension = rhs_spec[0]\n",
    "  proto.kernel_input_feature_dimension = rhs_spec[1]\n",
    "  proto.input_spatial_dimensions.extend(lhs_spec[2:])\n",
    "  proto.kernel_spatial_dimensions.extend(rhs_spec[2:])\n",
    "  proto.output_spatial_dimensions.extend(out_spec[2:])\n",
    "  return proto"
   ]
  },
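  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `lhs_spec`, `rhs_spec` and `out_spec` tuples hold dimension indices, not sizes. LAX can derive them from layout strings such as `('NCHW', 'OIHW', 'NCHW')`, which is what the all-`(0, 1, 2, 3)` specs used in the next cell correspond to. Below is a minimal sketch of that mapping; `layout_to_spec` is a hypothetical helper written for this illustration (not part of `xla_client` or LAX), and it simply keeps the spatial dimensions in the order they appear in the layout string.\n",
    "\n",
    "```python\n",
    "def layout_to_spec(layout, leading):\n",
    "  # Hypothetical helper. `layout` names each array dimension in order (e.g. 'NHWC');\n",
    "  # `leading` names the two non-spatial dimensions in spec order:\n",
    "  # 'NC' (batch, feature) for inputs/outputs, 'OI' (out, in feature) for kernels.\n",
    "  first = layout.index(leading[0])\n",
    "  second = layout.index(leading[1])\n",
    "  spatial = [i for i, name in enumerate(layout) if name not in leading]\n",
    "  return (first, second, *spatial)\n",
    "\n",
    "print(layout_to_spec('NCHW', 'NC'))  # (0, 1, 2, 3)\n",
    "print(layout_to_spec('OIHW', 'OI'))  # (0, 1, 2, 3)\n",
    "print(layout_to_spec('NHWC', 'NC'))  # (0, 3, 1, 2)\n",
    "print(layout_to_spec('HWIO', 'OI'))  # (3, 2, 0, 1)\n",
    "```\n",
    "\n",
    "Under this sketch, a channels-last input would be described by passing `lhs_spec=layout_to_spec('NHWC', 'NC')` into `ConvDimensionNumbers` before calling `_conv_general_proto`."
   ]
  },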
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "id": "J8QkirDalBse",
    "outputId": "543a03fd-f038-46f2-9a76-a6532b86874e"
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAABEYAAABdCAYAAACo5mNeAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAADjUlEQVR4nO3dwW0TURRAUTtKFVRBE4gKqJIKEE1QBWUwWSEFL4ideTMT+56zs7Kw9fRnc/Xn5bwsywkAAACg6OnoHwAAAABwFGEEAAAAyBJGAAAAgCxhBAAAAMgSRgAAAICs5//98cvTN/+yZoWff76fX382z3Vez9Ms1zHLOZ7zWc7mHLOc4zmf5WzOcTZnOZtzzHKO53zW5Tz/cmMEAAAAyBJGAAAAgCxhBAAAAMgSRgAAAIAsYQQAAADIEkYAAACALGEEAAAAyBJGAAAAgCxhBAAAAMgSRgAAAIAsYQQAAADIEkYAAACALGEEAAAAyBJGAAAAgCxhBAAAAMgSRgAAAIAsYQQAAADIEkYAAACALGEEAAAAyHo+8st//P71z+evnz4f9EsAAACAIjdGAAAAgCxhBAAAAMgSRgAAAICsQ3eMXO4UsXPkNm/NyzznmOX7md22zBcAANZxYwQAAADIEkYAAACALGEEAAAAyNp1x8jlu/Cs89ZOEa5nT8N2bj2nZn8bu4Xez56m/ZjlOua3HbMF4HRyYwQAAAAIE0YAAACALGEEAAAAyNp1x4j3Nvdl3tezr2U7dorMcjbneO7n2NOwLbuatmO30Dp2Ne3HLN/P7Lb1KPN1YwQAAADIEkYAAACALGEEAAAAyNp1xwjbutf3uT4is5xjlrPMcz9mfT37WrZlp8gcZ3OWZ3/Oo+xp+IjsadrWo+wWcmMEAAAAyBJGAAAAgCxhBAAAAMiyYwQATvfzDuw9MMtZ5jnHLPdl3tezr2U7dorMetSz6cYIAAAAkCWMAAAAAFnCCAAAAJBlxwgAALCaXQ1zzHKOWc561Hm6MQIAAABkCSMAAABAljACAAAAZAkjAAAAQJYwAgAAAGQJIwAAAECWMAIAAABkCSMAAABAljACAAAAZAkjAAAAQJYwAgAAAGQJIwAAAECWMAIAAABkCSMAAABAljACAAAAZAkjAAAAQJYwAgAAAGQJIwAAAECWMAIAAABkCSMAAABAljACAAAAZAkjAAAAQJYwAgAAAGQJIwAAAECWMAIAAABkCSMAAABAljACAAAAZAkjAAAAQJYwAgAAAGQJIwAAAECWMAIAAABkCSMAAABAljACAAAAZAkjAAAAQJYwAgAAAGSdl2U5+jcAAAAAHMKNEQAAACBLGAEAAACyhBEAAAAgSxgBAAAAsoQRAAAAIEsYAQAAALJeAPxyuoaRade9AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 1080x144 with 13 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light",
      "tags": []
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "Niter=13\n",
    "matrix_shape = (1, 1, 20, 20)\n",
    "in_shape_0 = xc.Shape.array_shape(np.dtype(np.int32), matrix_shape)\n",
    "in_shape_1 = xc.Shape.array_shape(np.dtype(np.int32), ())\n",
    "in_tuple_shape = xc.Shape.tuple_shape([in_shape_0, in_shape_1])\n",
    "\n",
    "# Body computation -- Conway Update\n",
    "bcb = xc.XlaBuilder(\"bodycomp\")\n",
    "intuple = xops.Parameter(bcb, 0, in_tuple_shape)\n",
    "x = xops.GetTupleElement(intuple, 0)\n",
    "cntr = xops.GetTupleElement(intuple, 1)\n",
    "# convs require floating-point type\n",
    "xf = xops.ConvertElementType(x, xc.DTYPE_TO_XLA_ELEMENT_TYPE['float32'])\n",
    "stamp = xops.Constant(bcb, np.ones((1,1,3,3), dtype=np.float32))\n",
    "conv_dim_num_proto = _conv_general_proto(ConvDimensionNumbers(lhs_spec=(0,1,2,3), rhs_spec=(0,1,2,3), out_spec=(0,1,2,3)))\n",
    "convd = xops.ConvGeneralDilated(xf, stamp, [1, 1], [(1, 1), (1, 1)], (), (), conv_dim_num_proto)\n",
    "# logic ops require integer types\n",
    "convd = xops.ConvertElementType(convd, xc.DTYPE_TO_XLA_ELEMENT_TYPE['int32'])\n",
    "bool_x = xops.Eq(x, xops.Constant(bcb, np.int32(1)))\n",
    "# core update rule\n",
    "res = xops.Or(\n",
    "    # birth rule\n",
    "    xops.And(xops.Not(bool_x), xops.Eq(convd, xops.Constant(bcb, np.int32(3)))),\n",
    "    # survival rule\n",
    "    xops.And(bool_x, xops.Or(\n",
    "        # survival needs 2 or 3 live neighbours; the conv sum also counts the cell itself, so test for 3 or 4\n",
    "        xops.Eq(convd, xops.Constant(bcb, np.int32(4))),\n",
    "        xops.Eq(convd, xops.Constant(bcb, np.int32(3))))\n",
    "    )\n",
    ")\n",
    "# Eq/And/Or yield booleans; convert back to int32 so the loop carry keeps its type\n",
    "int_res = xops.ConvertElementType(res, xc.DTYPE_TO_XLA_ELEMENT_TYPE['int32'])\n",
    "xops.Tuple(bcb, [int_res, xops.Sub(cntr, xops.Constant(bcb, np.int32(1)))])\n",
    "body_computation = bcb.Build()\n",
    "\n",
    "# Test computation -- just a for loop condition\n",
    "tcb = xc.XlaBuilder(\"testcomp\")\n",
    "intuple = xops.Parameter(tcb, 0, in_tuple_shape)\n",
    "cntr = xops.GetTupleElement(intuple, 1)\n",
    "test = xops.Gt(cntr, xops.Constant(tcb, np.int32(0)))\n",
    "test_computation = tcb.Build()\n",
    "\n",
    "# While computation:\n",
    "wcb = xc.XlaBuilder(\"whilecomp\")\n",
    "x = xops.Parameter(wcb, 0, in_shape_0)\n",
    "cntr = xops.Parameter(wcb, 1, in_shape_1)\n",
    "tuple_init_carry = xops.Tuple(wcb, [x, cntr])\n",
    "xops.While(test_computation, body_computation, tuple_init_carry)\n",
    "while_computation = wcb.Build()\n",
    "\n",
    "# Now compile and execute:\n",
    "cpu_backend = xc.get_local_backend(\"cpu\")\n",
    "\n",
    "# compile graph based on shape\n",
    "compiled_computation = cpu_backend.compile(while_computation)\n",
    "\n",
    "# Set up initial state\n",
    "X = np.zeros(matrix_shape, dtype=np.int32)\n",
    "X[0,0, 5:8, 5:8] = np.array([[0,1,0],[0,0,1],[1,1,1]])\n",
    "\n",
    "# Evolve: frame `it` reruns the While loop on the initial X for `it` steps\n",
    "movie = np.zeros((Niter,)+matrix_shape[-2:], dtype=np.int32)\n",
    "for it in range(Niter):\n",
    "  itr = np.array(it, dtype=np.int32)\n",
    "  device_input_x = cpu_backend.buffer_from_pyval(X)\n",
    "  device_input_it = cpu_backend.buffer_from_pyval(itr)\n",
    "  device_out = compiled_computation.execute([device_input_x, device_input_it])\n",
    "  movie[it] = device_out[0].to_py()[0,0]\n",
    "\n",
    "# Plot\n",
    "fig = plt.figure(figsize=(15,2))\n",
    "gs = gridspec.GridSpec(1,Niter)\n",
    "for i in range(Niter):\n",
    "  ax1 = plt.subplot(gs[:, i])\n",
    "  ax1.axis('off')\n",
    "  ax1.imshow(movie[i])\n",
    "plt.subplots_adjust(left=0.0, right=1.0, top=1.0, bottom=0.0, hspace=0.0, wspace=0.05)"
   ]
  },
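  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The body computation above performs one Game of Life step: a 3x3 all-ones convolution with one cell of zero padding counts the live cells in each neighbourhood (including the cell itself), and elementwise logic then applies the birth and survival rules. For comparison, here is a minimal sketch of the same update in plain NumPy, with the convolution written as a sum of nine shifted slices. `life_step` is our own name, and the check relies on the well-known property that this glider reappears shifted one cell down and one cell right every four generations.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def life_step(x):\n",
    "  # One Game of Life update on a 2-D 0/1 int array; cells outside the grid count as dead.\n",
    "  h, w = x.shape\n",
    "  padded = np.pad(x, 1)  # zero padding, like the (1, 1) conv padding above\n",
    "  # 3x3 box sum including the cell itself, i.e. a conv with np.ones((3, 3))\n",
    "  box = sum(padded[i:i + h, j:j + w] for i in range(3) for j in range(3))\n",
    "  alive = x == 1\n",
    "  born = ~alive & (box == 3)  # dead cell with exactly 3 live neighbours\n",
    "  survive = alive & ((box == 3) | (box == 4))  # 2 or 3 neighbours, plus self\n",
    "  return (born | survive).astype(np.int32)\n",
    "\n",
    "x0 = np.zeros((20, 20), dtype=np.int32)\n",
    "x0[5:8, 5:8] = [[0, 1, 0], [0, 0, 1], [1, 1, 1]]  # same glider as above\n",
    "x = x0\n",
    "for _ in range(4):\n",
    "  x = life_step(x)\n",
    "print(np.array_equal(x, np.roll(x0, (1, 1), axis=(0, 1))))  # True\n",
    "```"
   ]
  },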
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "9-0PJlqv237S"
   },
   "source": [
    "## Fin\n",
    "\n",
    "There's much more to XLA, but this hopefully highlights how easy it is to experiment with it via the Python client!"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
    "name": "XLA in Python.ipynb",
   "provenance": []
  },
  "jupytext": {
   "formats": "ipynb,md:myst"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
